// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <glog/logging.h>
#include <stddef.h>
#include <stdint.h>

#include <boost/iterator/iterator_facade.hpp>
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "util/counts.h"
#include "util/tdigest.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
#include "vec/common/assert_cast.h"
#include "vec/common/pod_array_fwd.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"

namespace doris::vectorized {
#include "common/compile_check_begin.h"

class Arena;
class BufferReadable;

/// Validates that `quantile` lies in the closed interval [0, 1].
///
/// The condition is expressed as a negated "inside the range" test so that a
/// NaN argument, for which every ordered comparison evaluates to false, is
/// rejected as well instead of slipping through.
///
/// @param quantile value to validate.
/// @throws Exception with ErrorCode::INVALID_ARGUMENT when `quantile` is NaN
///         or lies outside [0, 1].
inline void check_quantile(double quantile) {
    const bool in_range = quantile >= 0 && quantile <= 1;
    if (!in_range) {
        throw Exception(ErrorCode::INVALID_ARGUMENT,
                        "quantile in func percentile should in [0, 1], but real data is:" +
                                std::to_string(quantile));
    }
}

/// Aggregation state of the percentile_approx functions: a TDigest sketch
/// together with the quantile to report and the compression the sketch was
/// created with.
struct PercentileApproxState {
    /// Sentinel held by target_quantile until a real quantile has been set.
    static constexpr double INIT_QUANTILE = -1.0;
    PercentileApproxState() = default;
    ~PercentileApproxState() = default;

    /// One-time initialisation; once init_flag is set further calls are no-ops.
    ///
    /// @param quantile    quantile to report, validated by check_quantile().
    /// @param compression TDigest compression; replaced by 10000 when it is
    ///                    NaN or outside [2048, 10000].
    /// @throws Exception (via check_quantile) for an invalid quantile; in that
    ///         case the state is left unmodified.
    void init(double quantile, float compression = 10000) {
        if (init_flag) {
            return;
        }
        // Validate before allocating anything so a rejected quantile does not
        // leave a freshly created digest behind.
        check_quantile(quantile);

        //https://doris.apache.org/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile_approx.html#description
        // The compression parameter is accepted in the range [2048, 10000].
        // When it is not specified, or falls outside that range, the default
        // value 10000 is used. The test is written as a negated in-range check
        // so that NaN also falls back to the default.
        const bool compression_in_range = compression >= 2048 && compression <= 10000;
        if (!compression_in_range) {
            compression = 10000;
        }
        digest = TDigest::create_unique(compression);
        target_quantile = quantile;
        compressions = compression;
        init_flag = true;
    }

    /// Serializes the state. Layout: init_flag, then (only when initialised)
    /// target_quantile, compressions and the digest bytes as a binary string.
    void write(BufferWritable& buf) const {
        buf.write_binary(init_flag);
        if (!init_flag) {
            return;
        }

        buf.write_binary(target_quantile);
        buf.write_binary(compressions);
        // Check the pointer before the first dereference.
        DCHECK(digest.get() != nullptr);
        uint32_t serialize_size = digest->serialized_size();
        std::string result(serialize_size, '0');
        // data() yields a writable buffer; writing through c_str() is not allowed.
        digest->serialize(reinterpret_cast<uint8_t*>(result.data()));

        buf.write_binary(result);
    }

    /// Restores a state written by write(), replacing the current digest.
    void read(BufferReadable& buf) {
        buf.read_binary(init_flag);
        if (!init_flag) {
            return;
        }

        buf.read_binary(target_quantile);
        buf.read_binary(compressions);
        std::string str;
        buf.read_binary(str);
        digest = TDigest::create_unique(compressions);
        digest->unserialize(reinterpret_cast<uint8_t*>(str.data()));
    }

    /// @return the digest's estimate for target_quantile, or NaN when the
    ///         state has never been initialised.
    double get() const {
        if (!init_flag) {
            return std::nan("");
        }
        return digest->quantile(static_cast<float>(target_quantile));
    }

    /// Folds `rhs` into this state; an uninitialised `rhs` is ignored.
    ///
    /// When this state is itself uninitialised it adopts the compression of
    /// `rhs` before creating its digest, so a compression chosen by the caller
    /// is kept across the merge instead of being replaced by the local default.
    void merge(const PercentileApproxState& rhs) {
        if (!rhs.init_flag) {
            return;
        }
        if (!init_flag) {
            compressions = rhs.compressions;
            digest = TDigest::create_unique(compressions);
            init_flag = true;
        }
        DCHECK(digest.get() != nullptr);
        DCHECK(rhs.digest.get() != nullptr);
        digest->merge(rhs.digest.get());

        if (target_quantile == PercentileApproxState::INIT_QUANTILE) {
            target_quantile = rhs.target_quantile;
        }
    }

    /// Adds one observation. init() must have been called beforehand.
    void add(double source) {
        DCHECK(digest.get() != nullptr);
        digest->add(static_cast<float>(source));
    }

    /// Adds one weighted observation. init() must have been called beforehand.
    ///
    /// Only strictly positive weights are forwarded to the digest (the original
    /// note here states that TDigest checks this with DCHECK_GT(c._weight, 0)).
    /// Zero, negative and NaN weights are silently skipped.
    void add_with_weight(double source, double weight) {
        if (!(weight > 0)) {
            return;
        }
        DCHECK(digest.get() != nullptr);
        digest->add(static_cast<float>(source), static_cast<float>(weight));
    }

    /// Returns the state to "not initialised" with an empty digest that is
    /// created with the compression currently stored in `compressions`.
    void reset() {
        target_quantile = INIT_QUANTILE;
        init_flag = false;
        digest = TDigest::create_unique(compressions);
    }

    bool init_flag = false;
    std::unique_ptr<TDigest> digest;
    double target_quantile = INIT_QUANTILE;
    float compressions = 10000;
};

/// Common base of the percentile_approx aggregate functions.
///
/// It supplies the function name and forwards reset / merge / serialize /
/// deserialize to the PercentileApproxState stored at `place`. The subclasses
/// provide add(), get_return_type() and insert_result_into().
class AggregateFunctionPercentileApprox
        : public IAggregateFunctionDataHelper<PercentileApproxState,
                                              AggregateFunctionPercentileApprox> {
public:
    AggregateFunctionPercentileApprox(const DataTypes& argument_types_)
            : IAggregateFunctionDataHelper<PercentileApproxState,
                                           AggregateFunctionPercentileApprox>(argument_types_) {}

    String get_name() const override { return "percentile_approx"; }

    /// Resets the state stored at `place`.
    void reset(AggregateDataPtr __restrict place) const override {
        auto& agg_state = this->data(place);
        agg_state.reset();
    }

    /// Merges the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        const auto& incoming_state = this->data(rhs);
        this->data(place).merge(incoming_state);
    }

    /// Writes the state at `place` into `buf`.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const auto& agg_state = this->data(place);
        agg_state.write(buf);
    }

    /// Reads a previously serialized state from `buf` into `place`.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        auto& agg_state = this->data(place);
        agg_state.read(buf);
    }
};

/// percentile_approx with two arguments: the Float64 value column and the
/// Float64 quantile column. The default compression of the state is used.
class AggregateFunctionPercentileApproxTwoParams final : public AggregateFunctionPercentileApprox,
                                                         public MultiExpression,
                                                         public NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxTwoParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApprox(argument_types_) {}

    /// Feeds the value at `row_num` of columns[0] into the state.
    /// The quantile is always taken from row 0 of columns[1]; the state only
    /// uses it on its first initialisation.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& src_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);

        auto& agg_state = this->data(place);
        agg_state.init(quantile_col.get_element(0));
        agg_state.add(src_col.get_element(row_num));
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Appends the estimate to `to`. A NaN estimate (what the state reports
    /// when it was never initialised) is replaced by the column default.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& result_col = assert_cast<ColumnFloat64&>(to);
        const double estimate = this->data(place).get();

        if (!std::isnan(estimate)) {
            result_col.get_data().push_back(estimate);
            return;
        }
        result_col.insert_default();
    }
};

/// percentile_approx with three arguments: Float64 value column, Float64
/// quantile column and Float64 compression column.
class AggregateFunctionPercentileApproxThreeParams final : public AggregateFunctionPercentileApprox,
                                                           public MultiExpression,
                                                           public NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxThreeParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApprox(argument_types_) {}

    /// Feeds the value at `row_num` of columns[0] into the state.
    /// Quantile and compression are both read from row 0 of their columns; the
    /// compression is narrowed to float before it is handed to the state.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& src_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& compression_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]);

        auto& agg_state = this->data(place);
        agg_state.init(quantile_col.get_element(0),
                       static_cast<float>(compression_col.get_element(0)));
        agg_state.add(src_col.get_element(row_num));
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Appends the estimate to `to`. A NaN estimate (what the state reports
    /// when it was never initialised) is replaced by the column default.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& result_col = assert_cast<ColumnFloat64&>(to);
        const double estimate = this->data(place).get();

        if (!std::isnan(estimate)) {
            result_col.get_data().push_back(estimate);
            return;
        }
        result_col.insert_default();
    }
};

/// Weighted percentile_approx with three arguments: Float64 value column,
/// Float64 weight column and Float64 quantile column. The default compression
/// of the state is used.
class AggregateFunctionPercentileApproxWeightedThreeParams final
        : public AggregateFunctionPercentileApprox,
          MultiExpression,
          NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxWeightedThreeParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApprox(argument_types_) {}

    /// Feeds the value and weight found at `row_num` into the state.
    /// The quantile is read from row 0 of columns[2].
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& src_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& weight_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& quantile_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]);

        auto& agg_state = this->data(place);
        agg_state.init(quantile_col.get_element(0));
        agg_state.add_with_weight(src_col.get_element(row_num), weight_col.get_element(row_num));
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Appends the estimate to `to`. A NaN estimate (what the state reports
    /// when it was never initialised) is replaced by the column default.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& result_col = assert_cast<ColumnFloat64&>(to);
        const double estimate = this->data(place).get();

        if (!std::isnan(estimate)) {
            result_col.get_data().push_back(estimate);
            return;
        }
        result_col.insert_default();
    }
};

/// Weighted percentile_approx with four arguments: Float64 value column,
/// Float64 weight column, Float64 quantile column and Float64 compression
/// column.
class AggregateFunctionPercentileApproxWeightedFourParams final
        : public AggregateFunctionPercentileApprox,
          MultiExpression,
          NullableAggregateFunction {
public:
    AggregateFunctionPercentileApproxWeightedFourParams(const DataTypes& argument_types_)
            : AggregateFunctionPercentileApprox(argument_types_) {}

    /// Feeds the value and weight found at `row_num` into the state.
    /// Quantile and compression are read from row 0 of columns[2] and
    /// columns[3]; the compression is narrowed to float first.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& src_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& weight_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& quantile_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2]);
        const auto& compression_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[3]);

        auto& agg_state = this->data(place);
        agg_state.init(quantile_col.get_element(0),
                       static_cast<float>(compression_col.get_element(0)));
        agg_state.add_with_weight(src_col.get_element(row_num), weight_col.get_element(row_num));
    }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Appends the estimate to `to`. A NaN estimate (what the state reports
    /// when it was never initialised) is replaced by the column default.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& result_col = assert_cast<ColumnFloat64&>(to);
        const double estimate = this->data(place).get();

        if (!std::isnan(estimate)) {
            result_col.get_data().push_back(estimate);
            return;
        }
        result_col.insert_default();
    }
};

/// Aggregation state shared by percentile and percentile_array.
///
/// For every requested quantile vec_quantile[i] there is one Counts object
/// vec_counts[i]; each of them receives every source value. A quantile slot
/// holds -1 until its real value is known.
template <PrimitiveType T>
struct PercentileState {
    // mutable: the const members write() and insert_result_into() obtain
    // non-const access to the Counts objects through it.
    mutable std::vector<Counts<typename PrimitiveTypeTraits<T>::ColumnItemType>> vec_counts;
    std::vector<double> vec_quantile {-1};
    bool inited_flag = false;

    /// Serializes the state. Layout: inited_flag, then (only when initialised)
    /// the number of quantiles as int, the quantiles, and one serialized
    /// Counts per entry of vec_counts.
    void write(BufferWritable& buf) const {
        buf.write_binary(inited_flag);
        if (!inited_flag) {
            return;
        }
        int size_num = cast_set<int>(vec_quantile.size());
        buf.write_binary(size_num);
        for (const auto& quantile : vec_quantile) {
            buf.write_binary(quantile);
        }
        for (auto& counts : vec_counts) {
            counts.serialize(buf);
        }
    }

    /// Restores a state written by write(), replacing the current content.
    void read(BufferReadable& buf) {
        buf.read_binary(inited_flag);
        if (!inited_flag) {
            return;
        }
        int size_num = 0;
        buf.read_binary(size_num);
        double data = 0.0;
        vec_quantile.clear();
        for (int i = 0; i < size_num; ++i) {
            buf.read_binary(data);
            vec_quantile.emplace_back(data);
        }
        vec_counts.clear();
        vec_counts.resize(size_num);
        for (int i = 0; i < size_num; ++i) {
            vec_counts[i].unserialize(buf);
        }
    }

    /// Adds one source value to the first `arg_size` counters.
    ///
    /// On the first call the first `arg_size` entries of `quantiles` are
    /// validated and stored. All of them are checked before the state is
    /// modified, so a rejected argument does not leave a partially filled
    /// state that is already flagged as initialised.
    ///
    /// @param source    value to count.
    /// @param quantiles requested quantiles; entries [0, arg_size) are read.
    /// @param null_maps null flags for `quantiles`; entries [0, arg_size) are read.
    /// @param arg_size  number of quantiles.
    /// @throws Exception with ErrorCode::INVALID_ARGUMENT on the first call if
    ///         a quantile is null, or (via check_quantile) is not valid.
    void add(typename PrimitiveTypeTraits<T>::ColumnItemType source,
             const PaddedPODArray<Float64>& quantiles, const NullMap& null_maps, int64_t arg_size) {
        if (!inited_flag) {
            for (int64_t i = 0; i < arg_size; ++i) {
                // throw Exception func call percentile_array(id, [1,0,null])
                if (null_maps[i]) {
                    throw Exception(ErrorCode::INVALID_ARGUMENT,
                                    "quantiles in func percentile_array should not have null");
                }
                check_quantile(quantiles[i]);
            }
            vec_counts.resize(arg_size);
            vec_quantile.resize(arg_size, -1);
            for (int64_t i = 0; i < arg_size; ++i) {
                vec_quantile[i] = quantiles[i];
            }
            inited_flag = true;
        }
        for (int64_t i = 0; i < arg_size; ++i) {
            vec_counts[i].increment(source);
        }
    }

    /// Adds a whole batch of source values for a single quantile `q`.
    /// `q` is validated (before the state is touched) and stored on the first
    /// call only.
    void add_batch(const PaddedPODArray<typename PrimitiveTypeTraits<T>::ColumnItemType>& source,
                   const Float64& q) {
        if (!inited_flag) {
            check_quantile(q);
            vec_counts.resize(1);
            vec_quantile.resize(1);
            vec_quantile[0] = q;
            inited_flag = true;
        }
        vec_counts[0].increment_batch(source);
    }

    /// Folds `rhs` into this state; an uninitialised `rhs` is ignored.
    /// Quantile slots still holding -1 take their value from `rhs`.
    void merge(const PercentileState& rhs) {
        if (!rhs.inited_flag) {
            return;
        }
        int size_num = cast_set<int>(rhs.vec_quantile.size());
        if (!inited_flag) {
            vec_counts.resize(size_num);
            vec_quantile.resize(size_num, -1);
            inited_flag = true;
        }

        for (int i = 0; i < size_num; ++i) {
            if (vec_quantile[i] == -1.0) {
                vec_quantile[i] = rhs.vec_quantile[i];
            }
            vec_counts[i].merge(&(rhs.vec_counts[i]));
        }
    }

    /// Drops all counters and quantiles and marks the state uninitialised.
    void reset() {
        vec_counts.clear();
        vec_quantile.clear();
        inited_flag = false;
    }

    /// @return the result for the first quantile, or 0 when there are no counters.
    double get() const { return vec_counts.empty() ? 0 : vec_counts[0].terminate(vec_quantile[0]); }

    /// Appends one result per counter to `to`, which must be a ColumnFloat64.
    void insert_result_into(IColumn& to) const {
        auto& column_data = assert_cast<ColumnFloat64&>(to).get_data();
        for (size_t i = 0; i < vec_counts.size(); ++i) {
            column_data.push_back(vec_counts[i].terminate(vec_quantile[i]));
        }
    }
};

/// Exact `percentile(expr, quantile)` aggregate over a column of primitive
/// type T, backed by PercentileState<T>. The result type is Float64.
template <PrimitiveType T>
class AggregateFunctionPercentile final
        : public IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>,
          MultiExpression,
          NullableAggregateFunction {
public:
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using Base = IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentile<T>>;
    AggregateFunctionPercentile(const DataTypes& argument_types_) : Base(argument_types_) {}

    String get_name() const override { return "percentile"; }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Adds the value at `row_num` of columns[0] for a single quantile.
    /// The state is handed the whole quantile column data with a size of 1,
    /// plus a one-element all-zero null map.
    /// NOTE(review): the NullMap temporary is constructed on every call.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& src_col =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);

        auto& agg_state = this->data(place);
        agg_state.add(src_col.get_data()[row_num], quantile_col.get_data(), NullMap(1, 0), 1);
    }

    /// Adds every value of columns[0] in one call; the quantile is taken from
    /// row 0 of columns[1].
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena&) const override {
        const auto& src_col =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_col =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        DCHECK_EQ(src_col.get_data().size(), batch_size);

        auto& agg_state = this->data(place);
        agg_state.add_batch(src_col.get_data(), quantile_col.get_data()[0]);
    }

    /// Resets the state stored at `place`.
    void reset(AggregateDataPtr __restrict place) const override {
        auto& agg_state = this->data(place);
        agg_state.reset();
    }

    /// Merges the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        const auto& incoming_state = this->data(rhs);
        this->data(place).merge(incoming_state);
    }

    /// Writes the state at `place` into `buf`.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        const auto& agg_state = this->data(place);
        agg_state.write(buf);
    }

    /// Reads a previously serialized state from `buf` into `place`.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        auto& agg_state = this->data(place);
        agg_state.read(buf);
    }

    /// Appends the state's result for its first quantile to `to`.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& result_col = assert_cast<ColumnFloat64&>(to);
        const auto& agg_state = this->data(place);
        result_col.insert_value(agg_state.get());
    }
};

/// Exact `percentile_array(expr, [q1, q2, ...])` aggregate over a column of
/// primitive type T, backed by PercentileState<T>. The result is an array of
/// nullable Float64 holding one value per requested quantile.
template <PrimitiveType T>
class AggregateFunctionPercentileArray final
        : public IAggregateFunctionDataHelper<PercentileState<T>,
                                              AggregateFunctionPercentileArray<T>>,
          MultiExpression,
          NotNullableAggregateFunction {
public:
    using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
    using Base =
            IAggregateFunctionDataHelper<PercentileState<T>, AggregateFunctionPercentileArray<T>>;
    AggregateFunctionPercentileArray(const DataTypes& argument_types_) : Base(argument_types_) {}

    String get_name() const override { return "percentile_array"; }

    DataTypePtr get_return_type() const override {
        return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat64>()));
    }

    /// Adds the value at `row_num` of columns[0] to the state.
    ///
    /// columns[1] is an array column whose elements are nullable Float64. The
    /// state receives the complete flattened element data and null map plus
    /// the length of this row's array, and reads entries [0, length) of them.
    /// NOTE(review): presumably this relies on the quantile array being the
    /// same for every row - confirm against the caller.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const auto& sources =
                assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
        const auto& quantile_array =
                assert_cast<const ColumnArray&, TypeCheckOnRelease::DISABLE>(*columns[1]);
        const auto& offset_column_data = quantile_array.get_offsets();

        // Cast the element column once and take both the null map and the
        // nested values from it.
        const auto& nullable_elements =
                assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(
                        quantile_array.get_data());
        const auto& null_maps = nullable_elements.get_null_map_data();
        const auto& nested_column_data =
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(
                        nullable_elements.get_nested_column());

        // Length of this row's array: difference of neighbouring offsets. For
        // row 0 this reads index -1 of the offsets, exactly as the expression
        // it replaces did.
        AggregateFunctionPercentileArray::data(place).add(
                sources.get_element(row_num), nested_column_data.get_data(), null_maps,
                offset_column_data[row_num] - offset_column_data[row_num - 1]);
    }

    /// Resets the state stored at `place`.
    void reset(AggregateDataPtr __restrict place) const override {
        AggregateFunctionPercentileArray::data(place).reset();
    }

    /// Merges the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        AggregateFunctionPercentileArray::data(place).merge(
                AggregateFunctionPercentileArray::data(rhs));
    }

    /// Writes the state at `place` into `buf`.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        AggregateFunctionPercentileArray::data(place).write(buf);
    }

    /// Reads a previously serialized state from `buf` into `place`.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        AggregateFunctionPercentileArray::data(place).read(buf);
    }

    /// Appends one array row holding the state's results to `to`.
    ///
    /// When the array's element column is nullable, the values go into its
    /// nested column and the null map is extended with zeros to the same
    /// length. Finally the new end offset is pushed.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& to_arr = assert_cast<ColumnArray&>(to);
        auto& to_nested_col = to_arr.get_data();
        if (to_nested_col.is_nullable()) {
            // assert_cast performs a checked, properly adjusted downcast;
            // reinterpret_cast between IColumn and ColumnNullable does neither.
            auto& col_null = assert_cast<ColumnNullable&>(to_nested_col);
            AggregateFunctionPercentileArray::data(place).insert_result_into(
                    col_null.get_nested_column());
            col_null.get_null_map_data().resize_fill(col_null.get_nested_column().size(), 0);
        } else {
            AggregateFunctionPercentileArray::data(place).insert_result_into(to_nested_col);
        }
        to_arr.get_offsets().push_back(to_nested_col.size());
    }
};

#include "common/compile_check_end.h"
} // namespace doris::vectorized